
Refactor Target Platform Capabilities Design #1276

Merged 4 commits from `tpc_version` into `main` on Dec 2, 2024
Conversation

@lior-dikstein (Collaborator) commented on Nov 24, 2024:

  • Create a new schema package to house all target platform modeling classes
  • Introduce a new versioning system with minor and patch versions

Additional Changes:

  • Update existing target platform models to adhere to the new versioning convention
  • Add necessary metadata
  • Correct all import statements
  • Update and enhance tests to reflect the design changes

Pull Request Description:

Checklist before requesting a review:

  • I set the appropriate labels on the pull request.
  • I have added/updated the release note draft (if necessary).
  • I have updated the documentation to reflect my changes (if necessary).
  • All functions and files are well documented.
  • All functions and classes have type hints.
  • There is a license in every file.
  • Function and variable names are informative.
  • I have checked for code duplications.
  • I have added new unit tests (if necessary).

```diff
-from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \
-    TargetPlatformCapabilities, LayerFilterParams, OpQuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams
+from model_compression_toolkit.target_platform_capabilities.schema.v1 import OpQuantizationConfig, \
```
Collaborator:

Given that the schema is going to be updated (maybe frequently), and MCT core would change the schema it uses accordingly, we need to figure out how to modify these imports throughout the code without accessing ".v1" directly at each import.

Is there a way to "export" (as a TPC package "API") a default schema that references the currently used version, such that all imports point to it, and when we want to change the schema used by MCT we only have to change it in one place?

Maybe @irenaby would have an idea how it can be done?

Let's discuss this offline if needed.

Collaborator:

BTW, this is not mandatory for this PR, but solving it here would be better, because it would save us from editing all these files again in a separate PR.

Collaborator:

The only part of MCT that should be aware of the schema is the parser once we have it. It should parse the schema into whatever representation the rest of mct works with. If we want to reuse the same classes for now, we can add a proxy module that will only import the classes from the schema, and the rest of mct imports from that proxy module.

Collaborator:

I agree with @irenaby, we should have some proxy module at this stage...
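
A minimal sketch of the proxy-module idea, assuming it lives in a file like the `mct_current_schema.py` that a later commit in this PR adds (the exact set of re-exported classes is illustrative; `OpQuantizationConfig` and `QuantizationConfigOptions` are taken from this PR's diffs):

```python
# mct_current_schema.py -- the single place that pins the schema version MCT uses.
# The rest of the codebase imports schema classes from here, so moving to a
# future schema.v2 means editing only this module.
from model_compression_toolkit.target_platform_capabilities.schema.v1 import (  # noqa: F401
    OpQuantizationConfig,
    QuantizationConfigOptions,
)
```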

```diff
@@ -102,7 +102,7 @@ def get_op_quantization_configs() -> Tuple[OpQuantizationConfig, List[OpQuantiza
     signedness=Signedness.AUTO)

 # We define an 8-bit config for linear operations quantization, that include a kernel and bias attributes.
-linear_eight_bits = tp.OpQuantizationConfig(
+linear_eight_bits = model_compression_toolkit.target_platform_capabilities.schema.v1.OpQuantizationConfig(
```
Collaborator:

Maybe import the path as `schema_v1` (or create an alias at the beginning of the file) instead of repeating it everywhere? Or do we want it to be explicit in the TPC model description? @haihabi
Even if we do, eventually this file will be a JSON file rather than code written with imports, so these fields in the JSON won't include the entire schema path anyway...
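
A sketch of the suggested aliasing (the follow-up commit below notes the imports were in fact shortened to an alias named `schema`; `schema_v1` here just follows this comment's wording):

```python
# One alias at the top of the TPC model file:
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema_v1

# Every use site then shrinks from the fully qualified path in the diff above to:
# linear_eight_bits = schema_v1.OpQuantizationConfig(...)
```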

Collaborator:

But the JSON file should have the schema version from the constant in the schema main class.
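
For illustration only, an exported TPC JSON could then pin the version like this (the field names echo the metadata snippet later in this thread; the values and the platform type are hypothetical):

```json
{
  "schema_version": "1",
  "tpc_minor_version": "1",
  "tpc_patch_version": "0",
  "tpc_platform_type": "imx500"
}
```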

@ofirgo (Collaborator) commented on Nov 28, 2024:

> But the JSON file should have the schema version from the constant in the schema main class.

So you agree that the path here can be shortened with an alias to improve file readability? @haihabi

Collaborator:

Yes

tests/common_tests/test_tp_model.py: outdated review thread (resolved)
```python
tpc_patch_version = f'{tpc.tp_model.tpc_patch_version}'
tpc_platform_type = f'{tpc.tp_model.tpc_platform_type}'
tpc_schema = f'{tpc.tp_model.SCHEMA_VERSION}'
return {MCT_VERSION: mct_version,
```
Collaborator:

If you need to access the fields elsewhere, why not define a class or a named tuple? If this goes directly into the model, there is no need to define global constants.
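
A minimal sketch of that suggestion, assuming a frozen dataclass over the fields in the snippet above (the class name and values are hypothetical; the follow-up commit below only confirms that the metadata dictionary was replaced with a dataclass):

```python
from dataclasses import dataclass, asdict

@dataclass(frozen=True)
class TPCVersions:
    """Hypothetical container for the version fields gathered above."""
    mct_version: str
    tpc_minor_version: str
    tpc_patch_version: str
    tpc_platform_type: str
    tpc_schema: str

# Illustrative usage; asdict() recovers a plain dictionary if one is still needed downstream.
versions = TPCVersions('2.2.0', '1', '0', 'imx500', '1')
metadata = asdict(versions)
```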

1. Fixed imports from long strings to alias "schema".
2. Fixed some failing tests.
3. Removed unused files.
1. Created mct_current_schema.py for one place to update the schema version, and replaced all the imports in mct to work with this location.
2. Replaced metadata dictionary with dataclass.
```diff
@@ -5,7 +5,7 @@
 Several training methods may be applied by the user to train the QAT ready model
 created by `keras_quantization_aware_training_init` method in [`keras/quantization_facade`](../quantization_facade.py).
 Each `TrainingMethod` (an enum defined in the [`qat_config`](../../common/qat_config.py))
-and [`QuantizationMethod`](../../../target_platform_capabilities/target_platform/op_quantization_config.py)
+and `QuantizationMethod`
```
Collaborator:

fix the path according to the new file location

@lior-dikstein merged commit 28b461b into main on Dec 2, 2024
42 checks passed
@lior-dikstein deleted the tpc_version branch on December 2, 2024 at 13:11